{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import logging\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefix of the TensorFlow checkpoint that is converted to PyTorch below.\n",
    "tf_path = 'alxlnet-base-2020-04-10/model.ckpt-300000'\n",
    "# One (name, shape) pair per variable stored in the checkpoint.\n",
    "init_vars = tf.train.list_variables(tf_path)\n",
    "\n",
    "# Read every checkpoint variable into memory, keyed by its TF variable name,\n",
    "# so that individual tensors can be inspected in the following cells.\n",
    "tf_weights = dict()\n",
    "for var_name, var_shape in init_vars:\n",
    "    logger.info(\"Loading TF weight {} with shape {}\".format(var_name, var_shape))\n",
    "    tf_weights[var_name] = tf.train.load_variable(tf_path, var_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32000, 128)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First embedding table of the checkpoint; its second axis (128) matches the\n",
    "# first axis of `lookup_table_2`, which is inspected in the next cell.\n",
    "lookup_table = tf_weights['model/transformer/word_embedding/lookup_table']\n",
    "lookup_table.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(128, 768)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Second embedding table; judging by the shapes it presumably projects the\n",
    "# 128-wide embeddings up to the 768-wide model dimension (d_model).\n",
    "lookup_table_2 = tf_weights['model/transformer/word_embedding/lookup_table_2']\n",
    "lookup_table_2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32000,)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bias stored under the LM-loss scope; its length equals the vocabulary size.\n",
    "lm_bias = tf_weights['model/lm_loss/bias']\n",
    "lm_bias.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['beta1_power', 'beta2_power', 'global_step', 'model/lm_loss/bias', 'model/lm_loss/bias/Adam', 'model/lm_loss/bias/Adam_1', 'model/transformer/layer_shared/ff/LayerNorm/beta', 'model/transformer/layer_shared/ff/LayerNorm/beta/Adam', 'model/transformer/layer_shared/ff/LayerNorm/beta/Adam_1', 'model/transformer/layer_shared/ff/LayerNorm/gamma', 'model/transformer/layer_shared/ff/LayerNorm/gamma/Adam', 'model/transformer/layer_shared/ff/LayerNorm/gamma/Adam_1', 'model/transformer/layer_shared/ff/layer_1/bias', 'model/transformer/layer_shared/ff/layer_1/bias/Adam', 'model/transformer/layer_shared/ff/layer_1/bias/Adam_1', 'model/transformer/layer_shared/ff/layer_1/kernel', 'model/transformer/layer_shared/ff/layer_1/kernel/Adam', 'model/transformer/layer_shared/ff/layer_1/kernel/Adam_1', 'model/transformer/layer_shared/ff/layer_2/bias', 'model/transformer/layer_shared/ff/layer_2/bias/Adam', 'model/transformer/layer_shared/ff/layer_2/bias/Adam_1', 'model/transformer/layer_shared/ff/layer_2/kernel', 'model/transformer/layer_shared/ff/layer_2/kernel/Adam', 'model/transformer/layer_shared/ff/layer_2/kernel/Adam_1', 'model/transformer/layer_shared/rel_attn/LayerNorm/beta', 'model/transformer/layer_shared/rel_attn/LayerNorm/beta/Adam', 'model/transformer/layer_shared/rel_attn/LayerNorm/beta/Adam_1', 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma', 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma/Adam', 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma/Adam_1', 'model/transformer/layer_shared/rel_attn/k/kernel', 'model/transformer/layer_shared/rel_attn/k/kernel/Adam', 'model/transformer/layer_shared/rel_attn/k/kernel/Adam_1', 'model/transformer/layer_shared/rel_attn/o/kernel', 'model/transformer/layer_shared/rel_attn/o/kernel/Adam', 'model/transformer/layer_shared/rel_attn/o/kernel/Adam_1', 'model/transformer/layer_shared/rel_attn/q/kernel', 'model/transformer/layer_shared/rel_attn/q/kernel/Adam', 
'model/transformer/layer_shared/rel_attn/q/kernel/Adam_1', 'model/transformer/layer_shared/rel_attn/r/kernel', 'model/transformer/layer_shared/rel_attn/r/kernel/Adam', 'model/transformer/layer_shared/rel_attn/r/kernel/Adam_1', 'model/transformer/layer_shared/rel_attn/v/kernel', 'model/transformer/layer_shared/rel_attn/v/kernel/Adam', 'model/transformer/layer_shared/rel_attn/v/kernel/Adam_1', 'model/transformer/mask_emb/mask_emb', 'model/transformer/mask_emb/mask_emb/Adam', 'model/transformer/mask_emb/mask_emb/Adam_1', 'model/transformer/r_r_bias', 'model/transformer/r_r_bias/Adam', 'model/transformer/r_r_bias/Adam_1', 'model/transformer/r_s_bias', 'model/transformer/r_s_bias/Adam', 'model/transformer/r_s_bias/Adam_1', 'model/transformer/r_w_bias', 'model/transformer/r_w_bias/Adam', 'model/transformer/r_w_bias/Adam_1', 'model/transformer/seg_embed', 'model/transformer/seg_embed/Adam', 'model/transformer/seg_embed/Adam_1', 'model/transformer/word_embedding/lookup_table', 'model/transformer/word_embedding/lookup_table/Adam', 'model/transformer/word_embedding/lookup_table/Adam_1', 'model/transformer/word_embedding/lookup_table_2', 'model/transformer/word_embedding/lookup_table_2/Adam', 'model/transformer/word_embedding/lookup_table_2/Adam_1'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf_weights.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import XLNetConfig\n",
    "\n",
    "config = XLNetConfig.from_json_file('alxlnet-base-2020-04-10/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modeling_alxlnet import XLNetLMHeadModel, load_tf_weights_in_xlnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = XLNetLMHeadModel(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "XLNetLMHeadModel(\n",
       "  (transformer): XLNetModel(\n",
       "    (word_embedding): Embedding(32000, 128)\n",
       "    (word_embedding2): Embedding(128, 768)\n",
       "    (layer): ModuleList(\n",
       "      (0): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (1): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (2): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (3): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (4): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (5): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (6): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (7): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (8): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (9): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (10): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (11): XLNetLayer(\n",
       "        (rel_attn): XLNetRelativeAttention(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ff): XLNetFeedForward(\n",
       "          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layer_1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (layer_2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (lm_loss): Linear(in_features=768, out_features=32000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "load_tf_weights_in_xlnet(model, config, tf_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Create the export directory. exist_ok=True keeps this cell idempotent:\n",
    "# a bare `mkdir` reports an error on every re-run once the directory exists,\n",
    "# and the shell escape is not portable across platforms.\n",
    "os.makedirs('alxlnet-base', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import CONFIG_NAME, WEIGHTS_NAME\n",
    "\n",
    "# Destination files inside the export directory, using the file names that\n",
    "# transformers exposes as WEIGHTS_NAME / CONFIG_NAME.\n",
    "export_dir = 'alxlnet-base'\n",
    "pytorch_weights_dump_path = os.path.join(export_dir, WEIGHTS_NAME)\n",
    "pytorch_config_dump_path = os.path.join(export_dir, CONFIG_NAME)\n",
    "\n",
    "# Persist only the parameters (state dict) of the converted model.\n",
    "torch.save(model.state_dict(), pytorch_weights_dump_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the configuration next to the weights. to_json_file writes the\n",
    "# string returned by config.to_json_string() to the given path as UTF-8.\n",
    "config.to_json_file(pytorch_config_dump_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modeling_alxlnet import XLNetModel\n",
    "from transformers import XLNetTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('alxlnet-base/spiece.model',\n",
       " 'alxlnet-base/special_tokens_map.json',\n",
       " 'alxlnet-base/added_tokens.json')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the tokenizer from the vocabulary model file (lower-casing disabled)\n",
    "# and write its files into the export directory next to the weights.\n",
    "tokenizer = XLNetTokenizer('sp10m.cased.v9.model', do_lower_case=False)\n",
    "tokenizer.save_pretrained('alxlnet-base')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the tokenizer from the export directory to confirm that the files\n",
    "# saved in the previous cell are usable on their own.\n",
    "tokenizer = XLNetTokenizer.from_pretrained('./alxlnet-base', do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "XLNetConfig {\n",
       "  \"attn_type\": \"bi\",\n",
       "  \"bi_data\": false,\n",
       "  \"bos_token_id\": 1,\n",
       "  \"clamp_len\": -1,\n",
       "  \"d_head\": 64,\n",
       "  \"d_inner\": 3072,\n",
       "  \"d_model\": 768,\n",
       "  \"dropout\": 0.1,\n",
       "  \"end_n_top\": 5,\n",
       "  \"eos_token_id\": 2,\n",
       "  \"ff_activation\": \"gelu\",\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"layer_norm_eps\": 1e-12,\n",
       "  \"mem_len\": null,\n",
       "  \"model_type\": \"xlnet\",\n",
       "  \"n_head\": 12,\n",
       "  \"n_layer\": 12,\n",
       "  \"pad_token_id\": 5,\n",
       "  \"reuse_len\": null,\n",
       "  \"same_length\": false,\n",
       "  \"start_n_top\": 5,\n",
       "  \"summary_activation\": \"tanh\",\n",
       "  \"summary_last_dropout\": 0.1,\n",
       "  \"summary_type\": \"last\",\n",
       "  \"summary_use_proj\": true,\n",
       "  \"untie_r\": true,\n",
       "  \"vocab_size\": 32000\n",
       "}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the configuration from the exported config file rather than from\n",
    "# the original checkpoint folder, so the round-trip check below exercises\n",
    "# every exported artifact (tokenizer, weights and config).\n",
    "config = XLNetConfig.from_json_file(pytorch_config_dump_path)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the exported weights into the bare XLNetModel class; this rebinds\n",
    "# `model`, which previously held the XLNetLMHeadModel used for conversion.\n",
    "model = XLNetModel.from_pretrained('./alxlnet-base', config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode one sample sentence (special tokens included) and wrap the ids in\n",
    "# an outer list so the resulting tensor has a leading batch axis of size 1.\n",
    "token_ids = tokenizer.encode(\"husein tk suka mkan ayam\", add_special_tokens=True)\n",
    "input_ids = torch.tensor([token_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 0.0969,  0.7955, -0.3110,  ..., -0.4857,  1.0827, -0.8285],\n",
       "          [ 0.5202,  0.3320, -0.2579,  ..., -1.6332, -0.8103, -1.0283],\n",
       "          [-1.2109,  0.0866, -0.7792,  ..., -0.5070, -1.6178, -1.2713],\n",
       "          ...,\n",
       "          [-0.9953,  0.7751, -0.7880,  ..., -0.7724, -0.2612,  0.6644],\n",
       "          [ 0.3525, -0.6217, -1.0145,  ..., -1.4954, -1.7150, -0.2340],\n",
       "          [-0.2612, -0.4044, -0.7142,  ...,  0.2696, -0.6073,  0.5948]]],\n",
       "        grad_fn=<PermuteBackward>),)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the reloaded model with an all-ones mask of the same shape as\n",
    "# the input (presumably 1 marks a real, non-padding token -- the usual\n",
    "# transformers convention; confirm against modeling_alxlnet).\n",
    "attention_mask = torch.ones(input_ids.size())\n",
    "model(input_ids, attention_mask=attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
